/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef EXAMPLES_NORMALIZATION_LAYERNORM_CUSTOM_H
#define EXAMPLES_NORMALIZATION_LAYERNORM_CUSTOM_H
#include "kernel_operator.h"

namespace MyCustomKernel {
/**
 * Tiling parameters consumed by KernelLayernorm::Init.
 */
struct VecTiling {
    // Tiling descriptor; Init reads bLength / sLength / hLength from it and
    // later hands the whole descriptor to AscendC::LayerNorm.
    LayerNormTiling layernormTilingData;
    // Epsilon value forwarded unchanged to AscendC::LayerNorm; defaults to zero.
    float epsilon = 0.0f;
};

/**
 * Single-pass LayerNorm kernel wrapper.
 *
 * Init() binds six global-memory buffers and reserves one local buffer per
 * queue; Process() then copies the complete input, gamma and beta tensors into
 * local memory, invokes AscendC::LayerNorm once, and copies the normalized
 * output together with the mean and variance back to global memory. There is
 * no tiling loop: each tensor is moved in one DataCopy call.
 *
 * @tparam isReuseSource Forwarded as the second template argument of
 *         AscendC::LayerNorm; see the Ascend C LayerNorm documentation for its
 *         exact effect.
 */
template <bool isReuseSource = false> class KernelLayernorm {
public:
    __aicore__ inline KernelLayernorm() {}

    /**
     * Captures the tiling parameters, binds the global buffers and allocates
     * the local queues. Must be called before Process().
     *
     * @param inputXGm          Input buffer, viewed as bLength * sLength * hLength floats.
     * @param gammGm            Gamma buffer, viewed as hLength floats.
     * @param betaGm            Beta buffer, viewed as hLength floats.
     * @param outputGm          Output buffer, viewed as bLength * sLength * hLength floats.
     * @param outputMeanGm      Mean output buffer, viewed as bLength * sLength floats.
     * @param outputVarianceGm  Variance output buffer, viewed as bLength * sLength floats.
     * @param tilingData        Tiling descriptor and epsilon; copied into the kernel object.
     */
    __aicore__ inline void Init(GM_ADDR inputXGm, GM_ADDR gammGm, GM_ADDR betaGm, GM_ADDR outputGm,
        GM_ADDR outputMeanGm, GM_ADDR outputVarianceGm, VecTiling tilingData)
    {
        epsilon = tilingData.epsilon;
        tiling_ = tilingData.layernormTilingData;
        bLength = tiling_.bLength;
        sLength = tiling_.sLength;
        hLength = tiling_.hLength;

        // Element counts of the full tensor and of the per-row statistics.
        bshLength = bLength * sLength * hLength;
        bsLength = bLength * sLength;

        inputXGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(inputXGm), bshLength);
        gammaGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(gammGm), hLength);
        betaGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(betaGm), hLength);

        outputGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(outputGm), bshLength);
        outputMeanGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(outputMeanGm), bsLength);
        outputVarianceGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(outputVarianceGm), bsLength);

        // One buffer per queue, each sized to hold its whole tensor.
        // NOTE(review): local buffers and DataCopy transfers presumably need 32-byte
        // granularity, i.e. hLength and bLength * sLength being multiples of 8 floats -
        // confirm against the Ascend C InitBuffer / DataCopy documentation.
        pipe.InitBuffer(inQueueX, 1, sizeof(float) * bshLength);
        pipe.InitBuffer(inQueueGamma, 1, sizeof(float) * hLength);
        pipe.InitBuffer(inQueueBeta, 1, sizeof(float) * hLength);
        pipe.InitBuffer(outQueue, 1, sizeof(float) * bshLength);
        pipe.InitBuffer(outQueueMean, 1, sizeof(float) * bsLength);
        pipe.InitBuffer(outQueueVariance, 1, sizeof(float) * bsLength);
    }

    /**
     * Runs the three stages once, in order: CopyIn, Compute, CopyOut.
     */
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    // Stage 1: copy input, gamma and beta from global memory into freshly
    // allocated local tensors and enqueue them for Compute().
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<float> inputXLocal = inQueueX.AllocTensor<float>();
        AscendC::LocalTensor<float> gammaLocal = inQueueGamma.AllocTensor<float>();
        AscendC::LocalTensor<float> betaLocal = inQueueBeta.AllocTensor<float>();

        AscendC::DataCopy(inputXLocal, inputXGlobal, bshLength);
        AscendC::DataCopy(gammaLocal, gammaGlobal, hLength);
        AscendC::DataCopy(betaLocal, betaGlobal, hLength);

        inQueueX.EnQue(inputXLocal);
        inQueueGamma.EnQue(gammaLocal);
        inQueueBeta.EnQue(betaLocal);
    }

    // Stage 2: dequeue the inputs, run AscendC::LayerNorm into newly allocated
    // output / mean / variance tensors, enqueue the results and release the inputs.
    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<float> inputXLocal = inQueueX.DeQue<float>();
        AscendC::LocalTensor<float> gammaLocal = inQueueGamma.DeQue<float>();
        AscendC::LocalTensor<float> betaLocal = inQueueBeta.DeQue<float>();

        AscendC::LocalTensor<float> outputLocal = outQueue.AllocTensor<float>();
        AscendC::LocalTensor<float> meanLocal = outQueueMean.AllocTensor<float>();
        AscendC::LocalTensor<float> varianceLocal = outQueueVariance.AllocTensor<float>();

        // epsilon is already a float member, so it is passed without a cast.
        AscendC::LayerNorm<float, isReuseSource>(outputLocal, meanLocal, varianceLocal, inputXLocal, gammaLocal,
            betaLocal, epsilon, tiling_);

        outQueue.EnQue<float>(outputLocal);
        outQueueMean.EnQue<float>(meanLocal);
        outQueueVariance.EnQue<float>(varianceLocal);

        inQueueX.FreeTensor(inputXLocal);
        inQueueGamma.FreeTensor(gammaLocal);
        inQueueBeta.FreeTensor(betaLocal);
    }

    // Stage 3: dequeue the results, copy them to global memory and release
    // the local tensors.
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<float> outputLocal = outQueue.DeQue<float>();
        AscendC::LocalTensor<float> meanLocal = outQueueMean.DeQue<float>();
        AscendC::LocalTensor<float> varianceLocal = outQueueVariance.DeQue<float>();

        AscendC::DataCopy(outputGlobal, outputLocal, bshLength);
        AscendC::DataCopy(outputMeanGlobal, meanLocal, bsLength);
        AscendC::DataCopy(outputVarianceGlobal, varianceLocal, bsLength);

        outQueue.FreeTensor(outputLocal);
        outQueueMean.FreeTensor(meanLocal);
        outQueueVariance.FreeTensor(varianceLocal);
    }

private:
    // Global-memory views bound in Init().
    AscendC::GlobalTensor<float> inputXGlobal;
    AscendC::GlobalTensor<float> gammaGlobal;
    AscendC::GlobalTensor<float> betaGlobal;
    AscendC::GlobalTensor<float> outputGlobal;
    AscendC::GlobalTensor<float> outputMeanGlobal;
    AscendC::GlobalTensor<float> outputVarianceGlobal;

    // Pipe that backs the queues below; each queue has depth 1.
    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQueueX;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQueueGamma;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQueueBeta;
    AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQueue;
    AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQueueMean;
    AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQueueVariance;

    // Scalars are zero-initialized so the object holds no indeterminate
    // values before Init() overwrites them.
    uint32_t bLength = 0;
    uint32_t sLength = 0;
    uint32_t hLength = 0;
    float epsilon = 0.0f;
    LayerNormTiling tiling_;

    uint32_t bshLength = 0; // bLength * sLength * hLength
    uint32_t bsLength = 0;  // bLength * sLength
};
}
#endif // EXAMPLES_NORMALIZATION_LAYERNORM_CUSTOM_H